{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "85b9cad9-14cc-4145-a71a-6fdfa7b9b044",
   "metadata": {},
   "source": [
    "# BLEU Score for Unigrams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d152300f-00fd-4f31-bb2c-a776f1822a2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "original = \"Der schnelle braune Fuchs sprang ueber den faulen Hund\"\n",
    "\n",
    "reference =   \"The quick brown fox jumped over the lazy dog\"\n",
    "candidate_1 = \"The fast  brown fox leaped over the      dog\"\n",
    "candidate_2 = \"The swift brown fox jumped over the lazy dog\"\n",
    "candidate_3 = \"The swift tawny fox leaped over the indolent canine.\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3677a7a3-d4d1-49a0-ab90-d02b526cde04",
   "metadata": {},
   "source": [
    "### NLTK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6a159771-e2ee-4adc-b53d-d349d69a6a63",
   "metadata": {},
   "outputs": [],
   "source": [
    "#pip install nltk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cb45c369-e04f-43ea-a31f-e2380efa3e79",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLEU score for example 1: 0.66\n",
      "BLEU score for example 2: 0.89\n",
      "BLEU score for example 3: 0.44\n"
     ]
    }
   ],
   "source": [
    "from nltk.translate.bleu_score import sentence_bleu\n",
    "\n",
    "bleu_nltk_1 = sentence_bleu([reference.split()], candidate_1.split(), weights=[1.])\n",
    "bleu_nltk_2 = sentence_bleu([reference.split()], candidate_2.split(), weights=[1.])\n",
    "bleu_nltk_3 = sentence_bleu([reference.split()], candidate_3.split(), weights=[1.])\n",
    "\n",
    "print(f\"BLEU score for example 1: {bleu_nltk_1:.2f}\")\n",
    "print(f\"BLEU score for example 2: {bleu_nltk_2:.2f}\")\n",
    "print(f\"BLEU score for example 3: {bleu_nltk_3:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d89def98-71e9-4ee4-96ac-017dd9bf2a28",
   "metadata": {},
   "source": [
    "### TorchMetrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "01198d15-c16c-4aa6-a8ed-65705ab743cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLEU score for example 1: 0.66\n",
      "BLEU score for example 2: 0.89\n",
      "BLEU score for example 3: 0.44\n"
     ]
    }
   ],
   "source": [
    "from torchmetrics import BLEUScore\n",
    "\n",
    "bleu = BLEUScore(n_gram=1)\n",
    "\n",
    "# Calculate BLEU scores\n",
    "bleu_tm_1 = bleu(target=[[reference]], preds=[candidate_1])\n",
    "bleu_tm_2 = bleu(target=[[reference]], preds=[candidate_2])\n",
    "bleu_tm_3 = bleu(target=[[reference]], preds=[candidate_3])\n",
    "\n",
    "print(f\"BLEU score for example 1: {bleu_tm_1:.2f}\")\n",
    "print(f\"BLEU score for example 2: {bleu_tm_2:.2f}\")\n",
    "print(f\"BLEU score for example 3: {bleu_tm_3:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fe6a608-168a-4f7c-87a0-d82b9eac21ac",
   "metadata": {},
   "source": [
    "### From Scratch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "66075150-bb85-465f-945e-3cc008de7efd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLEU score for example 1: 0.66\n",
      "BLEU score for example 2: 0.89\n",
      "BLEU score for example 3: 0.44\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "from collections import Counter\n",
    "\n",
    "def ngrams(sentence, n):\n",
    "    return [tuple(sentence[i:i+n]) for i in range(len(sentence)-n+1)]\n",
    "\n",
    "def modified_precision(reference, candidate, n):\n",
    "    ref_ngrams = Counter(ngrams(reference, n))\n",
    "    cand_ngrams = Counter(ngrams(candidate, n))\n",
    "\n",
    "    count_clip = sum(min(cand_ngrams[ng], ref_ngrams[ng]) for ng in cand_ngrams)\n",
    "    count_total = sum(cand_ngrams.values())\n",
    "\n",
    "    return count_clip / count_total if count_total > 0 else 0\n",
    "\n",
    "def brevity_penalty(reference, candidate):\n",
    "    ref_len = len(reference)\n",
    "    cand_len = len(candidate)\n",
    "\n",
    "    if cand_len > ref_len:\n",
    "        return 1\n",
    "    elif cand_len == 0:\n",
    "        return 0\n",
    "    else:\n",
    "        return math.exp(1 - ref_len / cand_len)\n",
    "\n",
    "def bleu_score_unigram(reference, candidate):\n",
    "    bp = brevity_penalty(reference, candidate)\n",
    "    precision = modified_precision(reference, candidate, n=1)\n",
    "\n",
    "    return bp * precision\n",
    "\n",
    "\n",
    "bleu_scratch_1 = bleu_score_unigram(reference=reference.split(), candidate=candidate_1.split())\n",
    "bleu_scratch_2 = bleu_score_unigram(reference=reference.split(), candidate=candidate_2.split())\n",
    "bleu_scratch_3 = bleu_score_unigram(reference=reference.split(), candidate=candidate_3.split())\n",
    "\n",
    "print(f\"BLEU score for example 1: {bleu_scratch_1:.2f}\")\n",
    "print(f\"BLEU score for example 2: {bleu_scratch_2:.2f}\")\n",
    "print(f\"BLEU score for example 3: {bleu_scratch_3:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7b3782e-704e-40aa-8dee-062282ba1e37",
   "metadata": {},
   "source": [
    "# BLEU Score for 4-grams (\"default\" BLEU)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cf2bb62a-e1fd-4d5b-8d1d-3beadb72395b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example 1\n",
    "candidate_1 = \"The quick brown dog jumps over the lazy fox\"\n",
    "references_1 = [\n",
    "    \"The quick brown fox jumps over the lazy dog\",\n",
    "    \"The fast brown fox leaps over the lazy dog\",\n",
    "]\n",
    "\n",
    "# Example 2\n",
    "candidate_2 = \"The small red car drives quickly down the road\"\n",
    "references_2 = [\n",
    "    \"The small red car races quickly along the road\",\n",
    "    \"A small red car speeds rapidly down the avenue\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "288de8e8-a79a-4b8c-9822-be523a2bcbd3",
   "metadata": {},
   "source": [
    "## NLTK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2206ae02-71b4-44f8-8c9c-d939d1239e78",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLEU score for example 1: 0.46\n",
      "BLEU score for example 2: 0.40\n"
     ]
    }
   ],
   "source": [
    "from nltk.translate.bleu_score import sentence_bleu\n",
    "\n",
    "bleu_nltk_1 = sentence_bleu([r.split() for r in references_1], candidate_1.split())\n",
    "bleu_nltk_2 = sentence_bleu([r.split() for r in references_2], candidate_2.split())\n",
    "\n",
    "print(f\"BLEU score for example 1: {bleu_nltk_1:.2f}\")\n",
    "print(f\"BLEU score for example 2: {bleu_nltk_2:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d06ea55-af4a-4a47-bb67-4ed43f909b0b",
   "metadata": {},
   "source": [
    "## TorchMetrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "363b3809-8745-4937-8452-9e3c67c95d27",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLEU score for example 1: 0.46\n",
      "BLEU score for example 2: 0.40\n"
     ]
    }
   ],
   "source": [
    "from torchmetrics import BLEUScore\n",
    "\n",
    "bleu = BLEUScore(n_gram=4)\n",
    "\n",
    "# Calculate BLEU scores\n",
    "bleu_tm_1 = bleu(target=[references_1], preds=[candidate_1])\n",
    "bleu_tm_2 = bleu(target=[references_2], preds=[candidate_2])\n",
    "\n",
    "print(f\"BLEU score for example 1: {bleu_tm_1:.2f}\")\n",
    "print(f\"BLEU score for example 2: {bleu_tm_2:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fd1a273-3f82-43a9-8cd8-e56db1df1f6f",
   "metadata": {},
   "source": [
    "## From Scratch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "910e0742-68a4-452b-94c4-09eec1daf680",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLEU score for example 1: 0.46\n",
      "BLEU score for example 2: 0.40\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "from collections import Counter\n",
    "from fractions import Fraction\n",
    "\n",
    "def tokenize(sentence):\n",
    "    return sentence.lower().split()\n",
    "\n",
    "def ngrams(tokens, n):\n",
    "    return [tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1)]\n",
    "\n",
    "def modified_precision(candidate, references, n):\n",
    "    candidate_ngrams = Counter(ngrams(candidate, n))\n",
    "    max_reference_counts = Counter()\n",
    "\n",
    "    for reference in references:\n",
    "        reference_ngrams = Counter(ngrams(reference, n))\n",
    "        for ngram in candidate_ngrams:\n",
    "            max_reference_counts[ngram] = max(max_reference_counts[ngram], reference_ngrams[ngram])\n",
    "\n",
    "    clipped_counts = {\n",
    "        ngram: min(count, max_reference_counts[ngram])\n",
    "        for ngram, count in candidate_ngrams.items()\n",
    "    }\n",
    "\n",
    "    numerator = sum(clipped_counts.values())\n",
    "    denominator = sum(candidate_ngrams.values())\n",
    "\n",
    "    if denominator == 0:\n",
    "        return 0\n",
    "    return Fraction(numerator, denominator)\n",
    "\n",
    "def closest_reference_length(candidate, references):\n",
    "    ref_lens = [len(reference) for reference in references]\n",
    "    candidate_len = len(candidate)\n",
    "    closest_ref_len = min(ref_lens, key=lambda ref_len: (abs(ref_len - candidate_len), ref_len))\n",
    "    return closest_ref_len\n",
    "\n",
    "def brevity_penalty(candidate, references):\n",
    "    candidate_length = len(candidate)\n",
    "    closest_ref_len = closest_reference_length(candidate, references)\n",
    "\n",
    "    if candidate_length > closest_ref_len:\n",
    "        return 1\n",
    "    else:\n",
    "        return math.exp(1 - closest_ref_len / candidate_length)\n",
    "\n",
    "def sentence_bleu_scratch(candidate, references, weights=(0.25, 0.25, 0.25, 0.25)):\n",
    "    candidate_tokens = tokenize(candidate)\n",
    "    reference_tokens = [tokenize(reference) for reference in references]\n",
    "\n",
    "    precisions = [\n",
    "        modified_precision(candidate_tokens, reference_tokens, n+1)\n",
    "        for n in range(len(weights))\n",
    "    ]\n",
    "\n",
    "    if all(p == 0 for p in precisions):\n",
    "        return 0\n",
    "\n",
    "    precision_product = math.exp(\n",
    "        sum(w * math.log(float(p)) for w, p in zip(weights, precisions) if p != 0)\n",
    "    )\n",
    "    bp = brevity_penalty(candidate_tokens, reference_tokens)\n",
    "    bleu = bp * precision_product\n",
    "\n",
    "    return min(bleu, 1)  # Ensure the BLEU score is between 0 and 1\n",
    "\n",
    "\n",
    "bleu_score_scratch_1 = sentence_bleu_scratch(candidate_1, references_1)\n",
    "bleu_score_scratch_2 = sentence_bleu_scratch(candidate_2, references_2)\n",
    "\n",
    "print(f\"BLEU score for example 1: {bleu_score_scratch_1:.2f}\")\n",
    "print(f\"BLEU score for example 2: {bleu_score_scratch_2:.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
